{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rank Classification using BERT on Amazon Review\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In this tutorial, you learn how to use a pre-trained TensorFlow model to classify the star rating (rank) of an Amazon review. The model was fine-tuned on the Amazon Review dataset, starting from a pre-trained DistilBERT model.\n",
    "\n",
    "### About the dataset and model\n",
    "\n",
    "The [Amazon Customer Review dataset](https://s3.amazonaws.com/amazon-reviews-pds/readme.html) consists of valid customer reviews from amazon.com, grouped by product category. This tutorial uses the \"Digital_software\" category, which contains 102k valid reviews. For the pre-trained model, we use DistilBERT[[1]](https://arxiv.org/abs/1910.01108), a lightweight BERT model already pre-trained on [Wikipedia text corpora](https://en.wikipedia.org/wiki/List_of_text_corpora), a dataset much larger than the review data. DistilBERT serves as the base layer, and additional classification layers on top of it output a rank from 1 to 5.\n",
    "\n",
    "<img src=\"https://djl-ai.s3.amazonaws.com/resources/images/amazon_review.png\" width=\"500\">\n",
    "<center>Amazon Review example</center>\n",
    "\n",
    "\n",
    "## Pre-requisites\n",
    "This tutorial assumes you are familiar with the following. If you are not, follow the linked READMEs and tutorials:\n",
    "1. How to set up and run the [Java kernel in Jupyter Notebook](https://github.com/deepjavalibrary/djl/blob/master/jupyter/README.md)\n",
    "2. Basic components of the Deep Java Library (DJL), and how to [train your first model](https://github.com/deepjavalibrary/djl/blob/master/jupyter/tutorial/02_train_your_first_model.ipynb).\n",
    "\n",
    "\n",
    "## Getting started\n",
    "Load the Deep Java Library and its dependencies from Maven. The model in this tutorial is a TensorFlow model, so the cell below loads the DJL TensorFlow engine and its native library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.12.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "        \n",
    "%maven ai.djl.tensorflow:tensorflow-engine:0.12.0\n",
    "%maven ai.djl.tensorflow:tensorflow-api:0.12.0\n",
    "%maven org.bytedeco:javacpp:1.5.4\n",
    "%maven ai.djl.tensorflow:tensorflow-native-auto:2.4.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%loadFromPOM\n",
    "<dependency>\n",
    "    <groupId>com.google.protobuf</groupId>\n",
    "    <artifactId>protobuf-java</artifactId>\n",
    "    <version>3.8.0</version>\n",
    "</dependency>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's import the necessary modules:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.*;\n",
    "import ai.djl.engine.*;\n",
    "import ai.djl.inference.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.modality.nlp.*;\n",
    "import ai.djl.modality.nlp.bert.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.translate.*;\n",
    "import ai.djl.training.util.*;\n",
    "import ai.djl.util.*;\n",
    "\n",
    "import java.io.*;\n",
    "import java.nio.file.*;\n",
    "import java.util.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare your model files\n",
    "\n",
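    "You can download the pre-trained TensorFlow model from https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip. The following cell downloads the archive into the `build` folder and extracts it to `build/saved_model`."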
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String modelUrl = \"https://resources.djl.ai/demo/tensorflow/amazon_review_rank_classification.zip\";\n",
    "DownloadUtils.download(modelUrl, \"build/amazon_review_rank_classification.zip\", new ProgressBar());\n",
    "Path zipFile = Paths.get(\"build/amazon_review_rank_classification.zip\");\n",
    "\n",
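    "// Extract the archive once; skip this step if the model directory already exists\n",
    "Path modelDir = Paths.get(\"build/saved_model\");\n",
    "if (Files.notExists(modelDir)) {\n",
    "    // try-with-resources closes the zip input stream after extraction\n",
    "    try (InputStream is = Files.newInputStream(zipFile)) {\n",
    "        ZipUtils.unzip(is, modelDir);\n",
    "    }\n",
    "}"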
   ]
  },
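  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, list the extracted files to confirm that the archive was unpacked where the following cells expect it; `vocab.txt`, which the tokenizer cell loads, should appear at the top level. The sketch below uses only `java.nio.file`, and `listFiles` is a hypothetical helper introduced here for illustration, not a DJL API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical helper (not part of DJL): print each entry under `dir`, relative to `dir`,\n",
    "// descending at most `maxDepth` directory levels. `dir` itself is skipped.\n",
    "void listFiles(Path dir, int maxDepth) throws IOException {\n",
    "    try (java.util.stream.Stream<Path> paths = Files.walk(dir, maxDepth)) {\n",
    "        paths.filter(p -> !p.equals(dir))\n",
    "             .forEach(p -> System.out.println(dir.relativize(p)));\n",
    "    }\n",
    "}\n",
    "\n",
    "listFiles(modelDir, 2);"
   ]
  },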
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Translator\n",
    "\n",
    "Inference in deep learning is the process of predicting the output for a given input using a trained model.\n",
    "DJL abstracts away the whole process for ease of use: it loads the model, runs inference on the input, and returns the output.\n",
    "\n",
    "The `Translator` interface handles two steps: pre-processing and post-processing. The pre-processing step\n",
    "converts the user-defined input object into an `NDList`, so that the `Predictor` in DJL can understand the\n",
    "input and make its prediction. Similarly, the post-processing step receives an `NDList` as the output from the\n",
    "`Predictor` and converts it to the desired output format.\n",
    "\n",
    "### Pre-processing\n",
    "\n",
    "First, convert each review into tokens. DJL provides `BertFullTokenizer`, which splits text into BERT WordPiece tokens. Once you have the tokens, use the `Vocabulary` to map each token to its BERT index.\n",
    "\n",
    "The following code block loads the vocabulary and creates the tokenizer that the translator uses to turn a review into BERT-formatted input.\n",
    "\n",
    "The zip file also bundles the BERT `vocab.txt` file, which is used to build the vocabulary."
   ]
  },
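  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model takes fixed-length inputs (64 positions in this notebook), so each tokenized review is truncated or padded, and an attention mask marks which positions hold real tokens. The translator defined below also passes an all-zero segment-id array. The plain-Java sketch below shows the truncate-and-pad step on its own. `padAndMask` is a hypothetical helper and the token ids are made up, so treat it as an illustration of the idea, not the exact layout this model was trained with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical helper (not part of DJL): truncate or zero-pad token ids to `length`\n",
    "// and build the matching attention mask (1 = real token, 0 = padding).\n",
    "long[][] padAndMask(List<Long> tokenIds, int length) {\n",
    "    long[] ids = new long[length];  // Java zero-initializes arrays, so unused slots hold id 0\n",
    "    long[] mask = new long[length];\n",
    "    int size = Math.min(length, tokenIds.size()); // drop tokens beyond `length`\n",
    "    for (int i = 0; i < size; i++) {\n",
    "        ids[i] = tokenIds.get(i);\n",
    "        mask[i] = 1;\n",
    "    }\n",
    "    return new long[][] {ids, mask};\n",
    "}\n",
    "\n",
    "// Made-up token ids, padded to 8 positions\n",
    "long[][] padded = padAndMask(List.of(11L, 12L, 13L), 8);\n",
    "System.out.println(Arrays.toString(padded[0])); // [11, 12, 13, 0, 0, 0, 0, 0]\n",
    "System.out.println(Arrays.toString(padded[1])); // [1, 1, 1, 0, 0, 0, 0, 0]"
   ]
  },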
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Prepare the vocabulary\n",
    "Path vocabFile = modelDir.resolve(\"vocab.txt\");\n",
    "SimpleVocabulary vocabulary = SimpleVocabulary.builder()\n",
    "        .optMinFrequency(1)\n",
    "        .addFromTextFile(vocabFile)\n",
    "        .optUnknownToken(\"[UNK]\")\n",
    "        .build();\n",
    "BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);\n",
    "int maxTokenLength = 64; // number of token positions fed to the model; longer reviews are truncated\n"
   ]
  },
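  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the tokenizer and vocabulary produce before wiring them into a translator, you can run them on a short sample sentence. This cell uses only `tokenize` and `getIndex`, the same calls the translator makes; the exact tokens and ids depend on the bundled `vocab.txt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Split a sample sentence into tokens, then look up each token's vocabulary index\n",
    "List<String> sampleTokens = tokenizer.tokenize(\"It works great, but it slows the system\");\n",
    "List<Long> sampleIds = new ArrayList<>();\n",
    "for (String token : sampleTokens) {\n",
    "    sampleIds.add(vocabulary.getIndex(token));\n",
    "}\n",
    "System.out.println(sampleTokens);\n",
    "System.out.println(sampleIds);"
   ]
  },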
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyTranslator implements Translator<String, Classifications> {\n",
    "\n",
    "    private BertFullTokenizer tokenizer;\n",
    "    private SimpleVocabulary vocab;\n",
    "    private List<String> ranks;\n",
    "    private int length;\n",
    "\n",
    "    public MyTranslator(BertFullTokenizer tokenizer, int length) {\n",
    "        this.tokenizer = tokenizer;\n",
    "        this.length = length;\n",
    "        vocab = tokenizer.getVocabulary();\n",
    "        ranks = Arrays.asList(\"1\", \"2\", \"3\", \"4\", \"5\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Batchifier getBatchifier() {\n",
    "        return new StackBatchifier();\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public NDList processInput(TranslatorContext ctx, String input) {\n",
    "        List<String> tokens = tokenizer.tokenize(input);\n",
    "        long[] indices = new long[length];\n",
    "        long[] mask = new long[length];\n",
    "        long[] segmentIds = new long[length];\n",
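    "        // Token ids are written starting at position 1, so at most length - 1 tokens fit;\n",
    "        // without reserving that slot, reviews with `length` or more tokens would overflow `indices`.\n",
    "        int size = Math.min(length - 1, tokens.size());\n",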
    "        for (int i = 0; i < size; i++) {\n",
    "            indices[i + 1] = vocab.getIndex(tokens.get(i));\n",
    "        }\n",
    "        Arrays.fill(mask,  0, size, 1);\n",
    "        NDManager m = ctx.getNDManager();\n",
    "        return new NDList(m.create(indices), m.create(mask), m.create(segmentIds));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n",
    "        return new Classifications(ranks, list.singletonOrThrow().softmax(0));\n",
    "    }\n",
    "}\n"
   ]
  },
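  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`processOutput` calls `softmax(0)` to turn the five scores returned for a review into probabilities that sum to 1, and pairs them with the labels `1` to `5`. The plain-Java sketch below shows that computation, assuming the model outputs unnormalized scores (logits); the score values are made up for illustration. Subtracting the largest score before exponentiating leaves the result unchanged but avoids overflow for large scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Numerically stable softmax: shift by the max score, exponentiate, then normalize\n",
    "double[] softmax(double[] scores) {\n",
    "    double max = Arrays.stream(scores).max().orElse(0.0);\n",
    "    double[] probs = new double[scores.length];\n",
    "    double sum = 0;\n",
    "    for (int i = 0; i < scores.length; i++) {\n",
    "        probs[i] = Math.exp(scores[i] - max);\n",
    "        sum += probs[i];\n",
    "    }\n",
    "    for (int i = 0; i < probs.length; i++) {\n",
    "        probs[i] /= sum;\n",
    "    }\n",
    "    return probs;\n",
    "}\n",
    "\n",
    "// Hypothetical scores for ranks 1..5 (not real model output)\n",
    "double[] demoScores = {0.1, 0.3, 1.2, 2.5, 0.4};\n",
    "double[] demoProbs = softmax(demoScores);\n",
    "int bestIndex = 0;\n",
    "for (int i = 1; i < demoProbs.length; i++) {\n",
    "    if (demoProbs[i] > demoProbs[bestIndex]) {\n",
    "        bestIndex = i;\n",
    "    }\n",
    "}\n",
    "System.out.printf(\"most likely rank: %d (p = %.3f)%n\", bestIndex + 1, demoProbs[bestIndex]);"
   ]
  },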
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load your model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MyTranslator translator = new MyTranslator(tokenizer, maxTokenLength);\n",
    "\n",
    "Criteria<String, Classifications> criteria = Criteria.builder()\n",
    "        .setTypes(String.class, Classifications.class)\n",
    "        .optModelPath(modelDir) // load the model from the model directory\n",
    "        .optTranslator(translator) // use the custom translator\n",
    "        .build();\n",
    "\n",
    "ZooModel<String, Classifications> model = criteria.loadModel();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference\n",
    "\n",
    "Lastly, create a predictor from the model. Once you have a predictor, call its `predict` method on the test review."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String review = \"It works great, but it takes too long to update itself and slows the system\";\n",
    "\n",
    "Predictor<String, Classifications> predictor = model.newPredictor();\n",
    "Classifications classifications = predictor.predict(review);\n",
    "\n",
    "classifications"
   ]
  },
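  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Displaying `classifications` shows all five ranks with their probabilities. To use the prediction in code, you can ask for the most likely class with `best()`, or the top few with `topK(k)`; both return `Classification` entries that expose `getClassName()` and `getProbability()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Most likely rank for the review\n",
    "Classifications.Classification top = classifications.best();\n",
    "System.out.printf(\"Predicted rank: %s (probability %.3f)%n\", top.getClassName(), top.getProbability());\n",
    "\n",
    "// The three most likely ranks, highest probability first\n",
    "for (Classifications.Classification c : classifications.topK(3)) {\n",
    "    System.out.printf(\"%s: %.3f%n\", c.getClassName(), c.getProbability());\n",
    "}"
   ]
  },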
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
